{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QazSRJeXIbe2"
      },
      "source": [
        "# Train cross-lingual supervised text classifiers for elite criticism detection in political tweets using bag-of-words representations of machine-translated tweet texts\n",
        "\n",
        "*author:* Hauke Licht\n",
        "\n",
        "In this notebook, I train cross-lingual supervised text classifiers for elite criticism detection in political tweets using bag-of-words representations of machine-translated tweet texts.\n",
        "\n",
        "The labeled dataset contains 5.3K+ tweets sampled from the tweets that political parties from 20 Western countries posted between 2008 and early 2021.\n",
        "Each tweet was annotated by 6 crowd coders, and I have aggregated their annotations into tweet-level labels using a Dawid and Skene ([1979](https://doi.org/10.2307/2346806)) annotation model (cf. [Paun et al. 2018](https://aclanthology.org/Q18-1040.pdf)).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NVKKfzcPLA8T"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EQgctmDm6srd"
      },
      "source": [
        "## Set data paths"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "sB2uq4Jls4X4"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "base_path = os.path.join('..', '..')\n",
        "data_path = os.path.join(base_path, 'data')\n",
        "input_path = os.path.join(data_path, 'intermediate', 'training')\n",
        "res_dir = os.path.join(data_path, 'output', 'classifier_results')\n",
        "os.makedirs(res_dir, exist_ok = True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pejti7uxLECQ"
      },
      "source": [
        "## Load required packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "id": "YqjecWFs4iBN"
      },
      "outputs": [],
      "source": [
        "# for I/O\n",
        "import os\n",
        "import json\n",
        "\n",
        "# for data wrangling\n",
        "import pandas as pd\n",
        "import numpy as np\n",
        "\n",
        "# for pre-processing\n",
        "import nltk\n",
        "from nltk.tokenize import TweetTokenizer\n",
        "from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer\n",
        "\n",
        "# linear classifiers\n",
        "from sklearn.pipeline import Pipeline\n",
        "from sklearn.naive_bayes import MultinomialNB\n",
        "from sklearn.linear_model import SGDClassifier\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "\n",
        "# for train/test data preparation\n",
        "from sklearn.model_selection import GridSearchCV\n",
        "\n",
        "# for evaluation\n",
        "from sklearn.metrics import classification_report#, precision_recall_fscore_support\n",
        "\n",
        "# misc\n",
        "import random"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3WJ5-UwuqwKq"
      },
      "source": [
        "\n",
        "## Set global configuration"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "id": "lwb1BINhqwKq",
        "outputId": "58a3117b-f343-4885-b27e-3c80990e5eec"
      },
      "outputs": [],
      "source": [
        "# set the seed\n",
        "SEED = 1234\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XWqZ7LGMFeHV"
      },
      "source": [
        "# Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQWLAKg4aSeG"
      },
      "source": [
        "## Description"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BqBm2yyeYta4"
      },
      "source": [
        "The dataset we'll load has the following columns:\n",
        "\n",
        "- `item_id` (str): unique ID of the tweet (constructed by concatenating the ISO 3-character country code of the party posting the tweet, `user_id`, and `status_id`)\n",
        "- `user_id` (int): the ID of the account that has posted the tweet\n",
        "- `status_id` (int): the ID of the tweet\n",
        "- `labeling` (str): the label class a tweet has been assigned to (i.e., its label)\n",
        "- `text` (str): the tweet's text (in its original language)\n",
        "- `text_en` (str): the tweet's machine-translated text (into English)\n",
        "- `test_` (bool): flag indicating tweets that belong in the test (not the training) data split\n",
        "\n",
        "Note that user and status IDs are integers that can be very long.\n",
        "Hence, I'll read them as pandas' nullable `Int64` type to ensure they are not corrupted.\n",
        "(Alternatively, you could just treat them as strings.)\n",
        "\n",
        "To this end, I create a dictionary mapping column names to the desired data types that I'll pass when reading the CSV file:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "Wrj0Mxe2am0W"
      },
      "outputs": [],
      "source": [
        "col_types = {\n",
        "  'item_id': str,\n",
        "  'user_id': 'Int64',\n",
        "  'status_id': 'Int64',\n",
        "  'labeling': str,\n",
        "  'text': str,\n",
        "  'text_en': str,\n",
        "  'test_': bool\n",
        "}"
      ]
    },
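    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate why the ID columns need an exact integer (or string) type, here is a minimal sketch. It assumes that the risk is an ID being parsed as a 64-bit float, which represents every integer exactly only up to `2**53`. The variable names are hypothetical; the ID is one of the status IDs in this dataset.\n",
        "\n",
        "```python\n",
        "# one of the status IDs in this dataset\n",
        "status_id = 1054674229541650440\n",
        "\n",
        "# float64 represents every integer exactly only up to 2**53\n",
        "exceeds_exact_range = status_id > 2**53\n",
        "\n",
        "# round-trip through a float, as happens when a column is parsed as float64\n",
        "round_trip = int(float(status_id))\n",
        "is_lossless = round_trip == status_id\n",
        "\n",
        "print(f'exceeds 2**53: {exceeds_exact_range}; survives float round-trip: {is_lossless}')\n",
        "```\n",
        "\n",
        "Reading such IDs as strings avoids the problem as well, at the cost of losing integer semantics."
      ]
    },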
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rP65yYzRYLsb"
      },
      "source": [
        "## Load\n",
        "\n",
        "Read the labeled tweets dataset from the local input path:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 206
        },
        "id": "DsmH1rS_XKgr",
        "outputId": "2436cb64-ec56-4f2f-bf66-b446f705a77a"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>item_id</th>\n",
              "      <th>user_id</th>\n",
              "      <th>status_id</th>\n",
              "      <th>labeling</th>\n",
              "      <th>text</th>\n",
              "      <th>text_en</th>\n",
              "      <th>test_</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>IRL_19530527_19326417276</td>\n",
              "      <td>19530527</td>\n",
              "      <td>19326417276</td>\n",
              "      <td>yes-general</td>\n",
              "      <td>The Government has become the single biggest o...</td>\n",
              "      <td>The Government has become the single biggest o...</td>\n",
              "      <td>True</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>IRL_22628924_120756369459642370</td>\n",
              "      <td>22628924</td>\n",
              "      <td>120756369459642370</td>\n",
              "      <td>no</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>IRL_22628924_120756369459642370</td>\n",
              "      <td>22628924</td>\n",
              "      <td>120756369459642370</td>\n",
              "      <td>no</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>As president Martin McGuinness will use his in...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>ESP_83784273_1054674229541650440</td>\n",
              "      <td>83784273</td>\n",
              "      <td>1054674229541650440</td>\n",
              "      <td>no</td>\n",
              "      <td>El presidente del EBB, @andoniortuzar; la pres...</td>\n",
              "      <td>The president of EBB, @andoniortuzar; the pres...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>GRC_2808263274_1014889230466715653</td>\n",
              "      <td>2808263274</td>\n",
              "      <td>1014889230466715653</td>\n",
              "      <td>yes-specific</td>\n",
              "      <td>#BeLeventis Κατηγορείτε ως ακροδεξιούς, όσους ...</td>\n",
              "      <td>#BeLeventis Worship as far-right, those who do...</td>\n",
              "      <td>False</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                              item_id     user_id            status_id  \\\n",
              "0            IRL_19530527_19326417276    19530527          19326417276   \n",
              "1     IRL_22628924_120756369459642370    22628924   120756369459642370   \n",
              "2     IRL_22628924_120756369459642370    22628924   120756369459642370   \n",
              "3    ESP_83784273_1054674229541650440    83784273  1054674229541650440   \n",
              "4  GRC_2808263274_1014889230466715653  2808263274  1014889230466715653   \n",
              "\n",
              "       labeling                                               text  \\\n",
              "0   yes-general  The Government has become the single biggest o...   \n",
              "1            no  As president Martin McGuinness will use his in...   \n",
              "2            no  As president Martin McGuinness will use his in...   \n",
              "3            no  El presidente del EBB, @andoniortuzar; la pres...   \n",
              "4  yes-specific  #BeLeventis Κατηγορείτε ως ακροδεξιούς, όσους ...   \n",
              "\n",
              "                                             text_en  test_  \n",
              "0  The Government has become the single biggest o...   True  \n",
              "1  As president Martin McGuinness will use his in...  False  \n",
              "2  As president Martin McGuinness will use his in...  False  \n",
              "3  The president of EBB, @andoniortuzar; the pres...  False  \n",
              "4  #BeLeventis Worship as far-right, those who do...  False  "
            ]
          },
          "execution_count": 22,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "fp = os.path.join(input_path, 'training_data_pooled_samples.csv')\n",
        "dat = pd.read_csv(fp, sep = ',', dtype = col_types)\n",
        "dat.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "btNyCm4RYfiq"
      },
      "source": [
        "Set unique IDs as index."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "MW-9z7OLYPpL"
      },
      "outputs": [],
      "source": [
        "dat.set_index('item_id', inplace = True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n9kI6uU8cXP6"
      },
      "source": [
        "## Create binary labels\n",
        "\n",
        "Let's have a look at the `labeling` values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "s6gLkIe8cbi0",
        "outputId": "1076c4fe-2e7e-4c98-a76d-ac5653cfac16"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "labeling\n",
              "no              3344\n",
              "yes-general     1289\n",
              "yes-specific     768\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 24,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dat.labeling.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dw4127o6chWY"
      },
      "source": [
        "The labeling indicates whether a tweet contains\n",
        "\n",
        "1. **no** elite criticism,\n",
        "2. elite criticism directed at **the elite in general**, or\n",
        "3. criticism of **specific elites**.\n",
        "\n",
        "We argue that *the essence of anti-elite rhetoric* (as a political strategy) is generalized elite criticism.\n",
        "Hence, we are mainly interested in the distinction between 'general' elite criticism and all other statements.\n",
        "Accordingly, I create a **binary** label indicator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5Et8jtZGdPM7",
        "outputId": "cbaf5d9c-0337-4331-92d0-62e1ba149cac"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "label_\n",
              "0    4112\n",
              "1    1289\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 25,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dat['label_'] = dat.labeling == 'yes-general' # positive (negative) class label => True (False)\n",
        "dat['label_'] = dat['label_'].astype(int)  # positive (negative) class label => 1 (0)\n",
        "dat.label_.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j7VUkxPeqzme"
      },
      "source": [
        "# Preparation for training and evaluation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t4zQ6tGSb7iN"
      },
      "source": [
        "I next split the dataset into the training and test partitions.\n",
        "To do so, I use the `test_` indicator column that comes with the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6fkdzC2S0xnS",
        "outputId": "06524dad-c71c-4693-bbfa-833ba24c4edb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "No. train samples: 4342; pos. label proportion: 0.236\n",
            "No. test samples:  1059; pos. label proportion: 0.250\n"
          ]
        }
      ],
      "source": [
        "train_dat = dat[~dat.test_]\n",
        "print(f'No. train samples: {len(train_dat)}; pos. label proportion: {train_dat.label_.values.mean():.3f}')\n",
        "\n",
        "test_dat = dat[dat.test_]\n",
        "print(f'No. test samples:  {len(test_dat)}; pos. label proportion: {test_dat.label_.values.mean():.3f}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "3wY3mRy5sp4v",
        "outputId": "f757005c-f06f-4acb-e0eb-3c7ffd18e3a9"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "label_\n",
              "0    3318\n",
              "1    1024\n",
              "Name: count, dtype: int64"
            ]
          },
          "execution_count": 27,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "train_dat.label_.value_counts()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VbKiq70P1dOP"
      },
      "source": [
        "Next, I load the JSON file that records which tweets are in which CV folds:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "id": "9O0I5IHb65Dr"
      },
      "outputs": [],
      "source": [
        "with open(os.path.join(input_path, 'cv_ids.json'), 'r') as f:\n",
        "  cv_folds = json.load(f)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WHr9kUm99Z2R"
      },
      "source": [
        "Let's verify that the training data split can be subsetted by these IDs:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 29,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Ybg7Pv48qiuw",
        "outputId": "ecbcbf65-89c3-4544-8bfc-7a697d33de3d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "fold 0: # train = 3473, # test = 869\n",
            "fold 1: # train = 3473, # test = 869\n",
            "fold 2: # train = 3474, # test = 868\n",
            "fold 3: # train = 3474, # test = 868\n",
            "fold 4: # train = 3474, # test = 868\n"
          ]
        }
      ],
      "source": [
        "# verify CV splits\n",
        "for i, fold in cv_folds.items():\n",
        "  print(f\"fold {i}: # train = {len(fold['train'])}, # test = {len(fold['val'])}\")\n",
        "  # the next two lines raise a KeyError if any ID is missing from the index\n",
        "  train_dat.loc[ fold['train'] ]\n",
        "  train_dat.loc[ fold['val'] ]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5W3dKOVU9esH"
      },
      "source": [
        "I reconstruct each fold's positional row *indexes* from this information:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "4jxJvj7P3ue9"
      },
      "outputs": [],
      "source": [
        "cv_idxs = list()\n",
        "for fold in cv_folds.values():\n",
        "  # use a set for fast membership tests\n",
        "  train_ids = set(fold['train'])\n",
        "  train_idxs, val_idxs = list(), list()\n",
        "  # rows not in the fold's training IDs go to the validation part\n",
        "  for idx, item_id in enumerate(train_dat.index):\n",
        "    if item_id in train_ids:\n",
        "      train_idxs.append(idx)\n",
        "    else:\n",
        "      val_idxs.append(idx)\n",
        "  cv_idxs.append( ( np.asarray(train_idxs), np.asarray(val_idxs) ) )"
      ]
    },
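    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The mapping from IDs to row positions can also be written in vectorized form. Below is a minimal sketch on a toy index (the IDs and variable names are hypothetical). Like the loop above, it assumes that every row whose ID is not among a fold's training IDs belongs to that fold's validation part.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "# hypothetical toy index of item IDs (stands in for train_dat.index)\n",
        "index = pd.Index(['a', 'b', 'c', 'd', 'e'])\n",
        "# hypothetical fold: IDs assigned to the training part\n",
        "fold_train_ids = ['a', 'c', 'd']\n",
        "\n",
        "# boolean mask marking rows whose ID belongs to the fold's training part\n",
        "in_train = index.isin(fold_train_ids)\n",
        "\n",
        "# positions of the training rows and of the remaining (validation) rows\n",
        "train_idxs = np.flatnonzero(in_train)\n",
        "val_idxs = np.flatnonzero(~in_train)\n",
        "\n",
        "print(train_idxs, val_idxs)\n",
        "```\n",
        "\n",
        "Pairs of such position arrays are what `GridSearchCV` accepts as its `cv` argument."
      ]
    },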
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7vH7kBbC9imK"
      },
      "source": [
        "Finally, I define dictionaries for collecting the best hyperparameters, the cross-validation results, and the classifier evaluation results:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "AczVzB1Z4_Sf"
      },
      "outputs": [],
      "source": [
        "best_params = dict()\n",
        "cv_res = dict()\n",
        "performances = dict()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g2LE-XghzN5E"
      },
      "source": [
        "# Train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AtF212C0t_ju",
        "outputId": "cfbd66d5-03b5-478d-d820-98bd7e8d4c2c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 127 µs (started: 2023-10-19 17:31:22 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# track time\n",
        "%load_ext autotime"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dvoHMmthpFHl"
      },
      "source": [
        "Define the text vectorizer:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "8HyO3olvcHuF",
        "outputId": "4962763b-3a1b-4153-c0a4-bf831b612d85"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 440 µs (started: 2023-10-19 17:31:22 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# create a function for the tweet tokenizer from NLTK\n",
        "# source: https://predictivehacks.com/?all-tips=how-to-add-nltk-tokenizers-to-scikit-learn-tfidfvectorizer\n",
        "def tweet_tokenizer(text):\n",
        "    tt = TweetTokenizer()\n",
        "    return tt.tokenize(text)\n",
        "\n",
        "# create count vectorizer (i.e., creates DTM from corpus)\n",
        "vectorizer = CountVectorizer(\n",
        "  # stop words: ignore terms that have a document frequency strictly higher than the given threshold\n",
        "  max_df = .99,\n",
        "  # cut-off: ignore terms that have a document frequency strictly lower than the given threshold\n",
        "  min_df = 0.001,\n",
        "  max_features = 2500,\n",
        "  ngram_range = (1, 1),\n",
        "  dtype = 'int32',\n",
        "  analyzer = 'word',\n",
        "  tokenizer = tweet_tokenizer,\n",
        "  token_pattern=None,\n",
        "  strip_accents = 'unicode',\n",
        "  decode_error = 'replace'\n",
        ")"
      ]
    },
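    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the `min_df` and `max_df` cut-offs concrete, here is a minimal pure-Python sketch of document-frequency filtering on a hypothetical toy corpus. It uses whitespace tokenization and made-up thresholds, so it illustrates the idea rather than reproducing the exact behavior of `CountVectorizer`.\n",
        "\n",
        "```python\n",
        "from collections import Counter\n",
        "\n",
        "# hypothetical toy corpus (whitespace tokenization for simplicity)\n",
        "docs = [\n",
        "    'the elite ignores the people',\n",
        "    'the people want change',\n",
        "    'the government failed',\n",
        "    'the elite failed the people',\n",
        "]\n",
        "\n",
        "# document frequency: number of documents in which a term occurs at least once\n",
        "df = Counter(term for doc in docs for term in set(doc.split()))\n",
        "n_docs = len(docs)\n",
        "\n",
        "# illustrative thresholds (shares of documents), not the values used in this notebook\n",
        "min_df, max_df = 0.5, 0.99\n",
        "\n",
        "# keep terms whose document-frequency share lies within [min_df, max_df]\n",
        "vocab = sorted(t for t, n in df.items() if min_df <= n / n_docs <= max_df)\n",
        "print(vocab)\n",
        "```\n",
        "\n",
        "In this toy corpus, a term that occurs in every document is dropped by `max_df`, and terms that occur in a single document are dropped by `min_df`."
      ]
    },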
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MAdORJ85wMaN",
        "outputId": "ef69998d-032a-48ee-f17d-da5fc29b9e5f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 294 µs (started: 2023-10-19 17:31:22 +02:00)\n"
          ]
        }
      ],
      "source": [
        "nws = 0.5**np.asarray(range(1,5)) # [0.5, 0.25, 0.125, 0.0625]\n",
        "class_weights = [{0: nw, 1: 1-nw} for nw in nws]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "using 10 CPUs\n",
            "time: 309 µs (started: 2023-10-19 17:31:22 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# get the number of available CPUs\n",
        "import multiprocessing\n",
        "n_cpus = multiprocessing.cpu_count()\n",
        "print(f'using {n_cpus} CPUs')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-jTGChi-zQST"
      },
      "source": [
        "## Train Naive Bayes model\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "w8J7-6iHMKAO",
        "outputId": "0db77462-cfcb-47ec-d2ae-9155456e6544"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Fitting 5 folds for each of 16 candidates, totalling 80 fits\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 8.77 s (started: 2023-10-19 17:31:22 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# prepare pipeline components\n",
        "components = list()\n",
        "components += [('vect', vectorizer)]\n",
        "components += [('tfidf', TfidfTransformer())]\n",
        "\n",
        "# naive Bayes\n",
        "components += [('nb', MultinomialNB())]\n",
        "clf_nb = Pipeline(components)\n",
        "\n",
        "tuning_params = dict(\n",
        "  vect__ngram_range = [(1, 1), (1, 2)],\n",
        "  tfidf__use_idf = (True, False),\n",
        "  nb__class_prior = [list(ws.values()) for ws in class_weights]\n",
        ")\n",
        "\n",
        "# initialize grid searcher\n",
        "grid_search = GridSearchCV(\n",
        "  estimator = clf_nb,\n",
        "  param_grid = tuning_params,\n",
        "  cv = cv_idxs,\n",
        "  scoring = ['precision', 'recall', 'f1'],\n",
        "  refit = 'f1',\n",
        "  n_jobs = n_cpus,\n",
        "  verbose = 1\n",
        ")\n",
        "\n",
        "# train\n",
        "# note: this is where I pass the translated texts and the corresponding labels\n",
        "clf_nb = grid_search.fit(train_dat.text_en.values, train_dat.label_.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 37,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "A2SwoS0ArjyG",
        "outputId": "efd11198-0195-44e2-d82f-c59a3d330faf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.36 ms (started: 2023-10-19 17:31:31 +02:00)\n"
          ]
        }
      ],
      "source": [
        "cv_res['nb'] = pd.DataFrame(clf_nb.cv_results_)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 38,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EjaWkg_Qudni",
        "outputId": "52c8287b-314d-4d96-bcc5-ffb589edf084"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'nb__class_prior': [0.5, 0.5],\n",
              " 'tfidf__use_idf': True,\n",
              " 'vect__ngram_range': (1, 2)}"
            ]
          },
          "execution_count": 38,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.51 ms (started: 2023-10-19 17:31:31 +02:00)\n"
          ]
        }
      ],
      "source": [
        "best_params['nb'] = clf_nb.best_params_\n",
        "best_params['nb']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZQxo6mlnodnN"
      },
      "source": [
        "### Evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 39,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hgiIjICVodnO",
        "outputId": "be006312-e181-49ca-e269-e1cdf117de0b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 103 ms (started: 2023-10-19 17:31:31 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_nb.predict(test_dat.text_en.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eCWlpw7huprn",
        "outputId": "64b7930c-58f3-416a-80e2-63b4f42527a9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.79      0.75      0.77       794\n",
            "         pos       0.35      0.41      0.38       265\n",
            "\n",
            "    accuracy                           0.66      1059\n",
            "   macro avg       0.57      0.58      0.57      1059\n",
            "weighted avg       0.68      0.66      0.67      1059\n",
            "\n",
            "time: 5.46 ms (started: 2023-10-19 17:31:31 +02:00)\n"
          ]
        }
      ],
      "source": [
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'])\n",
        "print(res)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        },
        "id": "dXmRbEmnodnP",
        "outputId": "44723839-d9e3-47f4-db3e-9c0d7085ca68"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.791223</td>\n",
              "      <td>0.749370</td>\n",
              "      <td>0.769728</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.351792</td>\n",
              "      <td>0.407547</td>\n",
              "      <td>0.377622</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.663834</td>\n",
              "      <td>0.663834</td>\n",
              "      <td>0.663834</td>\n",
              "      <td>0.663834</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.571507</td>\n",
              "      <td>0.578459</td>\n",
              "      <td>0.573675</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.681262</td>\n",
              "      <td>0.663834</td>\n",
              "      <td>0.671609</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.791223  0.749370  0.769728   794.000000\n",
              "pos            0.351792  0.407547  0.377622   265.000000\n",
              "accuracy       0.663834  0.663834  0.663834     0.663834\n",
              "macro avg      0.571507  0.578459  0.573675  1059.000000\n",
              "weighted avg   0.681262  0.663834  0.671609  1059.000000"
            ]
          },
          "execution_count": 41,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 11.7 ms (started: 2023-10-19 17:31:31 +02:00)\n"
          ]
        }
      ],
      "source": [
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances['nb'] = pd.DataFrame(res).transpose()\n",
        "performances['nb']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KyXGDqZ67IXg"
      },
      "source": [
        "## Train a L2-regularized linear classifier (Perceptron)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 42,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "V-vaxsrw7OsX",
        "outputId": "784a119b-bbc2-42e7-eee1-d2434ea58a9f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Fitting 5 folds for each of 64 candidates, totalling 320 fits\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 26.1 s (started: 2023-10-19 17:31:31 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# prepare pipeline components\n",
        "components = list()\n",
        "components += [('vect', vectorizer)]\n",
        "components += [('tfidf', TfidfTransformer())]\n",
        "components += [('perceptron', SGDClassifier(loss = 'hinge', penalty = 'l2', fit_intercept = False, random_state = SEED))]\n",
        "clf_per = Pipeline(components)\n",
        "\n",
        "# the bigger alpha, the stronger regularization (i.e., more coefficients are shrunk towards zero)\n",
        "tuning_params = dict(\n",
        "  vect__ngram_range = [(1, 1), (1, 2)],\n",
        "  tfidf__use_idf = (True, False),\n",
        "  perceptron__alpha = [1e-4, 1e-5, 1e-6, 1e-7],\n",
        "  perceptron__class_weight = class_weights\n",
        ")\n",
        "\n",
        "# initialize grid searcher\n",
        "grid_search = GridSearchCV(\n",
        "  estimator = clf_per,\n",
        "  param_grid = tuning_params,\n",
        "  cv = cv_idxs,\n",
        "  scoring = ['precision', 'recall', 'f1'],\n",
        "  refit = 'f1',\n",
        "  n_jobs = n_cpus,\n",
        "  verbose = 1\n",
        ")\n",
        "\n",
        "# train\n",
        "clf_per = grid_search.fit(train_dat.text_en.values, train_dat.label_.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 43,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "CSCKM_9vs7Sm",
        "outputId": "bc20b5aa-4df7-4358-878b-075fb83f321e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.45 ms (started: 2023-10-19 17:31:57 +02:00)\n"
          ]
        }
      ],
      "source": [
        "cv_res['per'] = pd.DataFrame(clf_per.cv_results_)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 44,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2PGSRkgts7Sm",
        "outputId": "e9fdc446-159e-4bbe-e711-2b51501967d0"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'perceptron__alpha': 0.0001,\n",
              " 'perceptron__class_weight': {0: 0.25, 1: 0.75},\n",
              " 'tfidf__use_idf': False,\n",
              " 'vect__ngram_range': (1, 1)}"
            ]
          },
          "execution_count": 44,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 1.47 ms (started: 2023-10-19 17:31:57 +02:00)\n"
          ]
        }
      ],
      "source": [
        "best_params['per'] = clf_per.best_params_\n",
        "best_params['per']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ssZe7api783z"
      },
      "source": [
        "### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "7p-0zid3wY9y",
        "outputId": "97d225cd-e89f-4aa3-99c2-e0598f63d30c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 89.6 ms (started: 2023-10-19 17:31:57 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_per.predict(test_dat.text_en.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "W3_blZE3wdgY",
        "outputId": "0552d90e-6821-4a53-9d20-9968af34d4b9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.82      0.72      0.77       794\n",
            "         pos       0.39      0.54      0.45       265\n",
            "\n",
            "    accuracy                           0.67      1059\n",
            "   macro avg       0.61      0.63      0.61      1059\n",
            "weighted avg       0.72      0.67      0.69      1059\n",
            "\n",
            "time: 4.95 ms (started: 2023-10-19 17:31:57 +02:00)\n"
          ]
        }
      ],
      "source": [
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'])\n",
        "print(res)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 224
        },
        "id": "g8ldehdUwdgZ",
        "outputId": "7bdc0681-aaaa-4ff2-cf94-2fce66617745"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.824891</td>\n",
              "      <td>0.717884</td>\n",
              "      <td>0.767677</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.391304</td>\n",
              "      <td>0.543396</td>\n",
              "      <td>0.454976</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.674221</td>\n",
              "      <td>0.674221</td>\n",
              "      <td>0.674221</td>\n",
              "      <td>0.674221</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.608098</td>\n",
              "      <td>0.630640</td>\n",
              "      <td>0.611327</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.716392</td>\n",
              "      <td>0.674221</td>\n",
              "      <td>0.689428</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.824891  0.717884  0.767677   794.000000\n",
              "pos            0.391304  0.543396  0.454976   265.000000\n",
              "accuracy       0.674221  0.674221  0.674221     0.674221\n",
              "macro avg      0.608098  0.630640  0.611327  1059.000000\n",
              "weighted avg   0.716392  0.674221  0.689428  1059.000000"
            ]
          },
          "execution_count": 47,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 8.25 ms (started: 2023-10-19 17:31:57 +02:00)\n"
          ]
        }
      ],
      "source": [
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances['per'] = pd.DataFrame(res).transpose()\n",
        "performances['per']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mjk1G96n882J"
      },
      "source": [
        "## Train a Multi-layer Perceptron"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 48,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qOpCRZbI9APo",
        "outputId": "2260d6e4-6e8a-4741-bee3-0440077a21cd"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Fitting 5 folds for each of 9 candidates, totalling 45 fits\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 23.6 s (started: 2023-10-19 17:31:57 +02:00)\n"
          ]
        }
      ],
      "source": [
        "# prepare pipeline components\n",
        "components = list()\n",
        "components += [('vect', vectorizer)]\n",
        "components += [('tfidf', TfidfTransformer())]\n",
        "\n",
        "mlp_model = MLPClassifier(\n",
        "  activation = 'relu',\n",
        "  solver = 'adam',\n",
        "  hidden_layer_sizes = 100,\n",
        "  batch_size = 256,\n",
        "  random_state = SEED,\n",
        "  max_iter = 100,\n",
        "  learning_rate = 'adaptive',\n",
        "  early_stopping = True\n",
        ")\n",
        "\n",
        "components += [('mlp', mlp_model)]\n",
        "clf_mlp = Pipeline(components)\n",
        "\n",
        "\n",
        "tuning_params = dict(\n",
        "  mlp__batch_size = [64, 128, 256],\n",
        "  mlp__learning_rate_init = [1e-2, 1e-3, 1e-4]\n",
        ")\n",
        "\n",
        "\n",
        "# initialize grid searcher\n",
        "grid_search = GridSearchCV(\n",
        "  estimator = clf_mlp,\n",
        "  param_grid = tuning_params,\n",
        "  cv = cv_idxs,\n",
        "  scoring = ['precision', 'recall', 'f1'],\n",
        "  refit = 'f1',\n",
        "  n_jobs = n_cpus,\n",
        "  verbose = 1\n",
        ")\n",
        "\n",
        "# searchper\n",
        "clf_mlp = grid_search.fit(train_dat.text_en.values, train_dat.label_.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 49,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yEAhbMYZtbcS",
        "outputId": "06a6d160-77bc-409c-e148-9b1bf007c112"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'mlp__batch_size': 128, 'mlp__learning_rate_init': 0.01}"
            ]
          },
          "execution_count": 49,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 11.1 ms (started: 2023-10-19 17:32:21 +02:00)\n"
          ]
        }
      ],
      "source": [
        "cv_res['mlp'] = pd.DataFrame(clf_mlp.cv_results_)\n",
        "\n",
        "best_params['mlp'] = clf_mlp.best_params_\n",
        "best_params['mlp']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AR9r65HiDhof"
      },
      "source": [
        "### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 50,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 608
        },
        "id": "DBCxv3KXDhof",
        "outputId": "3081080f-7337-41ae-cd6f-04def475d752"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.75      1.00      0.86       794\n",
            "         pos       0.00      0.00      0.00       265\n",
            "\n",
            "    accuracy                           0.75      1059\n",
            "   macro avg       0.37      0.50      0.43      1059\n",
            "weighted avg       0.56      0.75      0.64      1059\n",
            "\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n",
            "/Users/hlicht/miniforge3/envs/advanced_text_analysis_gesis_2023/lib/python3.10/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
            "  _warn_prf(average, modifier, msg_start, len(result))\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.749764</td>\n",
              "      <td>1.000000</td>\n",
              "      <td>0.856989</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.000000</td>\n",
              "      <td>0.000000</td>\n",
              "      <td>0.000000</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.749764</td>\n",
              "      <td>0.749764</td>\n",
              "      <td>0.749764</td>\n",
              "      <td>0.749764</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.374882</td>\n",
              "      <td>0.500000</td>\n",
              "      <td>0.428494</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.562146</td>\n",
              "      <td>0.749764</td>\n",
              "      <td>0.642539</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.749764  1.000000  0.856989   794.000000\n",
              "pos            0.000000  0.000000  0.000000   265.000000\n",
              "accuracy       0.749764  0.749764  0.749764     0.749764\n",
              "macro avg      0.374882  0.500000  0.428494  1059.000000\n",
              "weighted avg   0.562146  0.749764  0.642539  1059.000000"
            ]
          },
          "execution_count": 50,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 184 ms (started: 2023-10-19 17:32:21 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_mlp.predict(test_dat.text_en.values)\n",
        "\n",
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'])\n",
        "print(res)\n",
        "\n",
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances['mlp'] = pd.DataFrame(res).transpose()\n",
        "performances['mlp']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ShOiWcxzzSp4"
      },
      "source": [
        "## Train linear SVM\n",
        "\n",
        "Note: takes relatively long to cross-validate it"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 51,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1SDxjWL1VsVL",
        "outputId": "fe75a437-b76b-4619-adc5-97cadffc3f98"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Fitting 5 folds for each of 16 candidates, totalling 80 fits\n",
            "time: 22.8 s (started: 2023-10-19 17:32:21 +02:00)\n"
          ]
        }
      ],
      "source": [
        "from sklearn.svm import SVC\n",
        "\n",
        "# prepare pipeline components\n",
        "components = list()\n",
        "components += [('vect', vectorizer)]\n",
        "components += [('tfidf', TfidfTransformer())]\n",
        "components += [('svm', SVC(kernel = 'linear', random_state=SEED))]\n",
        "clf_svm = Pipeline(components)\n",
        "\n",
        "tuning_params = {\n",
        "  'vect__ngram_range': [(1, 1), (1, 2)],\n",
        "  'tfidf__use_idf': (True, False),\n",
        "  'svm__C': [1, 2, 4, 8]\n",
        "}\n",
        "\n",
        "# initialize grid searcher\n",
        "grid_search = GridSearchCV(\n",
        "  estimator = clf_svm,\n",
        "  param_grid = tuning_params,\n",
        "  cv = cv_idxs,\n",
        "  scoring = ['precision', 'recall', 'f1'],\n",
        "  refit = 'f1',\n",
        "  n_jobs = n_cpus,\n",
        "  verbose = 1\n",
        ")\n",
        "\n",
        "# train\n",
        "# note: this is where I pass the translated texts' and corresponding labels\n",
        "clf_svm = grid_search.fit(train_dat.text_en.values, train_dat.label_.values)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 52,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "POzVo-c_uIVR",
        "outputId": "b0689be2-e244-4ef5-efe5-48cb58036e56"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'svm__C': 8, 'tfidf__use_idf': True, 'vect__ngram_range': (1, 2)}"
            ]
          },
          "execution_count": 52,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 2.4 ms (started: 2023-10-19 17:32:44 +02:00)\n"
          ]
        }
      ],
      "source": [
        "cv_res['svm'] = pd.DataFrame(clf_svm.cv_results_)\n",
        "\n",
        "best_params['svm'] = clf_svm.best_params_\n",
        "best_params['svm']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f6QbEwztzc-J"
      },
      "source": [
        "### Evaluate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 53,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 380
        },
        "id": "FSrI2Jgf_nwd",
        "outputId": "45947b22-5042-411d-c7e3-936a4267db8d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "              precision    recall  f1-score   support\n",
            "\n",
            "         neg       0.79      0.83      0.81       794\n",
            "         pos       0.39      0.34      0.36       265\n",
            "\n",
            "    accuracy                           0.70      1059\n",
            "   macro avg       0.59      0.58      0.59      1059\n",
            "weighted avg       0.69      0.70      0.70      1059\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>precision</th>\n",
              "      <th>recall</th>\n",
              "      <th>f1-score</th>\n",
              "      <th>support</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>neg</th>\n",
              "      <td>0.788715</td>\n",
              "      <td>0.827456</td>\n",
              "      <td>0.807621</td>\n",
              "      <td>794.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>pos</th>\n",
              "      <td>0.393805</td>\n",
              "      <td>0.335849</td>\n",
              "      <td>0.362525</td>\n",
              "      <td>265.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>accuracy</th>\n",
              "      <td>0.704438</td>\n",
              "      <td>0.704438</td>\n",
              "      <td>0.704438</td>\n",
              "      <td>0.704438</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>macro avg</th>\n",
              "      <td>0.591260</td>\n",
              "      <td>0.581652</td>\n",
              "      <td>0.585073</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>weighted avg</th>\n",
              "      <td>0.689895</td>\n",
              "      <td>0.704438</td>\n",
              "      <td>0.696242</td>\n",
              "      <td>1059.000000</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "              precision    recall  f1-score      support\n",
              "neg            0.788715  0.827456  0.807621   794.000000\n",
              "pos            0.393805  0.335849  0.362525   265.000000\n",
              "accuracy       0.704438  0.704438  0.704438     0.704438\n",
              "macro avg      0.591260  0.581652  0.585073  1059.000000\n",
              "weighted avg   0.689895  0.704438  0.696242  1059.000000"
            ]
          },
          "execution_count": 53,
          "metadata": {},
          "output_type": "execute_result"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 409 ms (started: 2023-10-19 17:32:44 +02:00)\n"
          ]
        }
      ],
      "source": [
        "preds = clf_svm.predict(test_dat.text_en.values)\n",
        "\n",
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'])\n",
        "print(res)\n",
        "\n",
        "res = classification_report(test_dat.label_.values, preds, labels = [0, 1], target_names = ['neg', 'pos'], output_dict = True)\n",
        "performances['svm'] = pd.DataFrame(res).transpose()\n",
        "performances['svm']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lSMT5eVP8MHJ"
      },
      "source": [
        "# Write all to disk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 54,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iPBe19XT8PNI",
        "outputId": "c00e7b74-840b-4fcd-dec9-d834eca1f2e4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 11.5 ms (started: 2023-10-19 17:32:44 +02:00)\n"
          ]
        }
      ],
      "source": [
        "fp = os.path.join(res_dir, 'bow+mt_cv_results.tab')\n",
        "pd.concat(cv_res, names = ['model', 'param_set']).to_csv(fp, sep = '\\t')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 55,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0sSELCmt9bvQ",
        "outputId": "be4e4f05-5a76-44f1-a63a-5393ef0b9c6c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 2.23 ms (started: 2023-10-19 17:32:44 +02:00)\n"
          ]
        }
      ],
      "source": [
        "fp = os.path.join(res_dir, 'bow+mt_best_params.json')\n",
        "with open(fp, 'w') as f:\n",
        "  json.dump(best_params, f)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 56,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RMlEh0yz82-f",
        "outputId": "99977e72-fa6d-4634-f8bb-395727d76ab7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "time: 2.47 ms (started: 2023-10-19 17:32:44 +02:00)\n"
          ]
        }
      ],
      "source": [
        "fp = os.path.join(res_dir, 'bow+mt_test_results.tab')\n",
        "pd.concat(performances, names = ['model', 'what']).to_csv(fp, sep = '\\t')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
